# Import necessary libraries
import pandas as pd
import random
import time
from tqdm import tqdm
from vllm import LLM, SamplingParams
import os
import transformers
import json
# Separator line printed between sections of the debug output.
dashed_line = '-' * 50


# The same local checkpoint path is handed to both the vLLM engine and the
# Hugging Face tokenizer, so the chat template matches the served model.
model_name = '/home/models/deepseek-llm-7b-chat'
pipeline = LLM(model=model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

# One completion per prompt (n=1), temperature 0.0, no top-p / top-k
# truncation, and at most 1024 generated tokens.
sampling_params = SamplingParams(
    temperature=0.0,
    top_p=1,
    top_k=-1,
    n=1,
    max_tokens=1024,
)

def format_inputs(prompts):
    """Wrap every prompt in a single-turn chat conversation.

    Args:
        prompts: Iterable of prompt values; each one is rendered to text
            through an f-string.

    Returns:
        list: One conversation per prompt, where a conversation is a
        one-element list holding a ``{"role": "user", "content": ...}``
        message.
    """
    return [[{"role": "user", "content": f"{text}"}] for text in prompts]

def postprocess_aya_output(output):
    """Extract the reply text from each response in ``output``.

    For every response, the first candidate's ``'generated_text'`` entry is
    indexed at position 1 and its ``'content'`` field is collected.

    Args:
        output: Iterable of responses, each indexable as
            ``response[0]['generated_text'][1]['content']``.

    Returns:
        list: The extracted ``'content'`` values, in input order.
    """
    replies = []
    for response in output:
        first_candidate = response[0]
        messages = first_candidate['generated_text']
        replies.append(messages[1]['content'])
    return replies

def llm_inference(prompt):
    """Run a single prompt through the module-level vLLM pipeline.

    The prompt is wrapped as a one-turn chat, rendered to text with the
    tokenizer's chat template (generation prompt appended, no tokenization),
    and generated with the module-level ``sampling_params``.

    Args:
        prompt: The user prompt to send to the model.

    Returns:
        The text of the first candidate of the first generation result.
    """
    conversations = format_inputs([prompt])
    rendered = tokenizer.apply_chat_template(
        conversations,
        tokenize=False,
        add_generation_prompt=True,
    )
    request_outputs = pipeline.generate(rendered, sampling_params)
    # Collect the first candidate's text of every result, then keep the
    # first one (only one prompt was submitted).
    texts = [request.outputs[0].text for request in request_outputs]
    return texts[0]
    

def sample_allowed_document_ids(df, col="News ID"):
    """Keep the rows whose document number lies in ``[10000, 30000)``.

    The document number is obtained by removing every ``'N'`` from the
    string form of the ID and parsing the rest as a float; only
    integer-valued numbers inside the window are kept.

    Args:
        df (pd.DataFrame): Frame containing the ID column.
        col (str): Name of the ID column (default ``"News ID"``).

    Returns:
        pd.DataFrame: The matching rows after ``reset_index()``, so the
        previous index is retained as an ``index`` column.

    Raises:
        ValueError: If an ID cannot be parsed as a number.
    """
    lower, upper = 10000, 30000

    def _is_allowed(doc_id):
        number = float(str(doc_id).replace('N', ''))
        # Equivalent to `number in range(lower, upper)`, but O(1): membership
        # of a float in a range falls back to a linear scan of the range.
        return lower <= number < upper and number.is_integer()

    kept_ids = [doc_id for doc_id in df[col].values.tolist() if _is_allowed(doc_id)]
    return df[df[col].isin(kept_ids)].reset_index()

def sample_disallowed_document_ids(df, col="News ID"):
    """Keep the rows whose document number lies outside ``[10000, 30000)``.

    The document number is obtained by removing every ``'N'`` from the
    string form of the ID and parsing the rest as a float. An ID is kept
    unless its number is integer-valued and inside the window.

    Args:
        df (pd.DataFrame): Frame containing the ID column.
        col (str): Name of the ID column (default ``"News ID"``).

    Returns:
        pd.DataFrame: The matching rows after ``reset_index()``, so the
        previous index is retained as an ``index`` column.

    Raises:
        ValueError: If an ID cannot be parsed as a number.
    """
    lower, upper = 10000, 30000

    def _is_disallowed(doc_id):
        number = float(str(doc_id).replace('N', ''))
        # Equivalent to `number not in range(lower, upper)`, but O(1):
        # membership of a float in a range falls back to a linear scan.
        return not (lower <= number < upper and number.is_integer())

    kept_ids = [doc_id for doc_id in df[col].values.tolist() if _is_disallowed(doc_id)]
    return df[df[col].isin(kept_ids)].reset_index()


def generate_fewshot_examples(df, num_examples=4):
    """
    Build a few-shot example string from randomly chosen rows of ``df``.

    At most ``num_examples`` rows are drawn (fewer when the frame is
    smaller). The pandas sampling seed is itself drawn from the ``random``
    module, so seeding ``random`` makes the selection reproducible.

    Args:
    - df (pd.DataFrame): DataFrame containing columns 'News ID', 'News body', and 'Headline'.
    - num_examples (int): Upper bound on the number of rows to sample (default is 4).

    Returns:
    - str: One "Document / Document Content / Headline" section per sampled
      row, with sections separated by a blank line.
    """
    sample_size = min(num_examples, len(df))
    seed = random.randint(1, 10000)
    chosen_rows = df.sample(n=sample_size, random_state=seed)

    sections = []
    for _, record in chosen_rows.iterrows():
        sections.append(
            f"Document: {record['News ID']}\n"
            f"Document Content: {record['News body']}\n"
            f"Headline: {record['Headline']}"
        )

    return "\n\n".join(sections)


# Function 2: Generate Action for Each Document
def generate_action_for_docs(interaction_output):
    """Parse the model's JSON object into parallel document/action lists.

    Args:
        interaction_output (str): JSON text expected to encode an object
            mapping document IDs to actions.

    Returns:
        tuple[list, list]: ``(docs, actions)`` - the object's keys and its
        values, in the order they appear in the JSON text.

    Raises:
        ValueError: If the text is not valid JSON or does not decode to a
            JSON object. The offending text is printed first and the
            original error is chained as the cause.
    """
    try:
        interaction_dict = json.loads(interaction_output)
        if not isinstance(interaction_dict, dict):
            raise TypeError(
                f"expected a JSON object, got {type(interaction_dict).__name__}"
            )
    except (TypeError, ValueError) as err:
        # json.JSONDecodeError is a ValueError; TypeError covers non-string
        # input and non-object payloads. Unrelated exceptions such as
        # KeyboardInterrupt are deliberately left to propagate untouched.
        print(interaction_output)
        raise ValueError("Could not parse interaction output as a JSON object") from err

    return list(interaction_dict.keys()), list(interaction_dict.values())
    
# Function 6: Aggregate trajectories and summaries
# NOTE: an identical definition of this function appears again further down
# in this file; at import time the later one replaces this one.
def aggregate_results(user_trajectories, summaries):
    """Convert the collected records into two DataFrames.

    Args:
        user_trajectories: Records describing each user's trajectory.
        summaries: Records describing each generated summary.

    Returns:
        tuple: ``(trajectory_dataset, summary_dataset)``, each built with
        ``pd.DataFrame`` from the corresponding argument.
    """
    return pd.DataFrame(user_trajectories), pd.DataFrame(summaries)

def generate_personalized_summary(user_id, doc_id, headline):
    """Build the summary record for one generated headline.

    The summary ID is ``"S"`` followed by a random integer in
    ``[1, 20000]``. IDs are drawn independently on every call, so two
    records may end up with the same ID.

    Args:
        user_id: Identifier of the user the headline was generated for.
        doc_id: Identifier of the summarized document.
        headline: The generated headline (stored as given).

    Returns:
        dict: Record with keys ``"Summary ID"``, ``"User ID"``,
        ``"News ID"`` and ``"Expected Headline"``.
    """
    record = {"Summary ID": f"S{random.randint(1, 20000)}"}
    record["User ID"] = user_id
    record["News ID"] = doc_id
    record["Expected Headline"] = headline
    return record

# Function 5: Combine user trajectory details
def combine_user_trajectory(user_id, docs, actions):
    """Assemble the trajectory record for a single user.

    Args:
        user_id: Identifier of the user.
        docs (list): Document / summary IDs in visit order.
        actions (list): Actions aligned with ``docs``.

    Returns:
        dict: Record holding the user ID, the two sequences (the same list
        objects that were passed in, not copies) and the number of
        ``"sumgen"`` entries found in ``actions``.
    """
    summary_node_count = actions.count("sumgen")

    trajectory = {}
    trajectory["User ID"] = user_id
    trajectory["Sequence of Docs"] = docs
    trajectory["Sequence of Actions"] = actions
    trajectory["Number of Summary Nodes"] = summary_node_count
    return trajectory


# Function 6: Aggregate trajectories and summaries
# NOTE: this repeats an identical definition found earlier in this file;
# being the later one, it is the definition in effect at call time.
def aggregate_results(user_trajectories, summaries):
    """Turn the accumulated records into the two output tables.

    Args:
        user_trajectories: Per-user trajectory records.
        summaries: Per-headline summary records.

    Returns:
        tuple: ``(trajectory_dataset, summary_dataset)`` as DataFrames.
    """
    trajectory_frame = pd.DataFrame(user_trajectories)
    summary_frame = pd.DataFrame(summaries)
    result = (trajectory_frame, summary_frame)
    return result


def prompt_template_click_skip_sequence(user_id, doc_sequence):
    """Render the prompt asking the model for a click/skip action per document.

    Args:
        user_id: Identifier interpolated into the prompt.
        doc_sequence: Document IDs interpolated into the prompt as-is.

    Returns:
        str: The prompt with leading and trailing whitespace stripped.
    """
    prompt = f"""
Generate a sequence of interactions for User {user_id}.
Document IDs: {doc_sequence}
Actions: [click, skip]
Rules:
- Each action corresponds to a document in the sequence.
- The sequence should only contain "click" or "skip" actions at this stage.

Please format your output as a Json object {{document1 : action1, document2 : action2 }}.
Output must only contain the Json object. 
"""
    return prompt.strip()


def prompt_template_handle_gensum_and_sumgen(user_id, doc_id, doc_content, fewshot_examples):
    """Render the few-shot prompt asking for a personalized headline as JSON.

    Args:
        user_id: Identifier interpolated into the "User:" line.
        doc_id: Identifier interpolated into the "Document:" line.
        doc_content: Text interpolated into the "Document Content:" line.
        fewshot_examples: Pre-formatted examples placed under "### Examples".

    Returns:
        str: The prompt with leading and trailing whitespace stripped.
    """
    prompt = f"""
### Examples
Below are examples of personalized headlines generated for different users based on their document content:

{fewshot_examples}

---
### Task:
Generate a highly personalized headline based on the document content. Ensure that the headline aligns with the user's preferences and effectively captures the essence of the document.

User: {user_id}  
Document: {doc_id}  
Document Content: {doc_content}  

Strictly return a JSON object in the following format: 
{{
    "headline": "your generated headline"
}}
"""
    return prompt.strip()


# Document numbers eligible to be picked as the extra "gensum" document.
_GENSUM_DOC_NUMBER_RANGE = range(10000, 30000)


def _doc_number(doc_id):
    """Return the integer after the one-character prefix of an ID, or None if malformed."""
    try:
        return int(str(doc_id)[1:])
    except ValueError:
        return None


def _generate_headline(user_id, doc_id, doc_content, context_data, max_attempts=3):
    """Ask the model for a headline, retrying while its reply cannot be parsed.

    Returns ``(headline, headline_prompt)``; ``headline`` is None when every
    attempt failed to yield a JSON object with a "headline" key (or when the
    model returned a null headline).
    """
    headline = None
    headline_prompt = None
    for _ in range(max_attempts):
        # Fresh few-shot examples are sampled for every attempt.
        fewshot_examples = generate_fewshot_examples(context_data, num_examples=4)
        headline_prompt = prompt_template_handle_gensum_and_sumgen(
            user_id, doc_id, doc_content, fewshot_examples
        )
        response = llm_inference(headline_prompt)
        try:
            headline = json.loads(response)["headline"]
        except (ValueError, KeyError, TypeError) as e:
            print(e)
            print(response)
            continue
        break
    return headline, headline_prompt


# Main function to execute the chained process
def generate_trajectories_and_summaries(news_data, context_data, num_users=100, trajectory_length=10):
    """Generate a synthetic interaction trajectory and headline for each user.

    For every user: sample ``trajectory_length`` document IDs (with
    replacement), ask the model for a click/skip action per document,
    append one extra "gensum" document, generate a personalized headline
    for it, and insert a "sumgen" step carrying the summary ID right after
    the "gensum" step.

    Args:
        news_data (pd.DataFrame): Target documents; must have the columns
            'News ID' and 'News body'.
        context_data (pd.DataFrame): Documents used as few-shot examples.
        num_users (int): Number of users to simulate; user IDs run from 1
            to ``num_users``.
        trajectory_length (int): Number of documents sampled per user
            before the gensum document is appended.

    Returns:
        tuple: ``(trajectory_dataset, summary_dataset)`` DataFrames.

    Notes:
        Users whose click/skip reply cannot be parsed, or for whom no unused
        gensum document exists, are skipped and produce no trajectory.
    """
    user_trajectories = []
    summaries = []
    progress_bar = tqdm(range(1, num_users + 1), desc="Generating User Trajectories")

    # Loop-invariant lookups, built once. On duplicate IDs the first row
    # wins, matching a positional "first match" lookup.
    unique_news = news_data.drop_duplicates(subset="News ID", keep="first")
    body_by_id = dict(zip(unique_news["News ID"], unique_news["News body"]))
    # Only documents that actually exist in news_data may be chosen as the
    # gensum document; picking an arbitrary number from the range could
    # name a document with no body to summarize.
    gensum_pool = [
        doc_id for doc_id in body_by_id
        if _doc_number(doc_id) in _GENSUM_DOC_NUMBER_RANGE
    ]

    for user_id in progress_bar:
        start_time = time.time()

        # Step 1: Generate Click/Skip Sequence
        doc_sequence = list(news_data["News ID"].sample(trajectory_length, replace=True))

        interaction_prompt1 = prompt_template_click_skip_sequence(user_id, doc_sequence)
        interaction_output1 = llm_inference(interaction_prompt1)

        print(f"Interaction prompt 1:\n{interaction_prompt1}")
        print(dashed_line)
        print(f"Interaction output 1:\n{interaction_output1}")
        print(dashed_line)

        # Step 2: Identify Gensum Interaction
        try:
            docs, actions = generate_action_for_docs(interaction_output1)
        except Exception:
            # Best effort: an unparseable reply skips this user only.
            continue

        print(f"Docs: {docs}")
        print(dashed_line)
        print(f"Actions: {actions}")
        print(dashed_line)

        # Pick a gensum document that is not already part of the trajectory.
        already_used = set(docs)
        candidates = [doc_id for doc_id in gensum_pool if doc_id not in already_used]
        if not candidates:
            print(f"No unused gensum document available for user {user_id}; skipping")
            continue
        actions.append("gensum")
        docs.append(random.choice(candidates))
        print(f"Docs: {docs}")
        print(dashed_line)
        print(f"Actions: {actions}")
        print(dashed_line)

        # Step 3 & 4: Handle gensum and Generate Personalized Summary
        # New lists are built instead of inserting into the lists being
        # iterated over.
        final_docs = []
        final_actions = []
        for doc_id, action in zip(docs, actions):
            final_docs.append(doc_id)
            final_actions.append(action)
            if action != "gensum":
                continue
            if doc_id not in body_by_id:
                # Only reachable when the model itself emitted "gensum" for
                # an ID that is not in news_data; keep the step but do not
                # attach a summary to it.
                print(f"Document {doc_id} not found in news data; no summary generated")
                continue

            doc_content = body_by_id[doc_id]
            headline, headline_prompt = _generate_headline(
                user_id, doc_id, doc_content, context_data
            )

            print(f"Headline Prompt:\n{headline_prompt}")
            print(dashed_line)
            print(f"Generated Headline:\n{headline}")
            print(dashed_line)

            summary = generate_personalized_summary(user_id, doc_id, headline)
            summaries.append(summary)

            # The sumgen step directly follows its gensum step.
            final_actions.append("sumgen")
            final_docs.append(summary["Summary ID"])

        docs, actions = final_docs, final_actions

        print(f"Docs: {docs}")
        print(dashed_line)
        print(f"Actions: {actions}")
        print(dashed_line)
        print(dashed_line)
        print(dashed_line)

        # Step 5: Combine Outputs for Trajectory
        trajectory = combine_user_trajectory(user_id, docs, actions)
        user_trajectories.append(trajectory)
        # Log only the new trajectory; dumping the whole accumulated list on
        # every iteration makes the log grow quadratically with num_users.
        print(f"User trajectory:\n{trajectory}")
        print(dashed_line)

        # Track time for each user
        progress_bar.set_postfix({"Time per Interaction (s)": round(time.time() - start_time, 2)})

    # Step 6: Aggregate Results
    return aggregate_results(user_trajectories, summaries)


if __name__ == '__main__':
    num_users = 20000

    # Create the output directory before the long generation run:
    # DataFrame.to_csv does not create missing parent directories, so a
    # missing folder would otherwise only surface after all users have
    # been processed.
    output_dir = "output/prompt_experiment"
    os.makedirs(output_dir, exist_ok=True)

    news_data = pd.read_csv("data/news_dataset.tsv", sep="\t")  # Assuming TSV format
    print(f"Total data: {news_data.shape}")
    # Few-shot context comes from the disallowed ID range and targets from
    # the allowed range, so the two sets do not overlap. The context split
    # must be taken before news_data is overwritten with the filtered frame.
    context_data = sample_disallowed_document_ids(news_data)
    print(f"Context data: {context_data.shape}")
    news_data = sample_allowed_document_ids(news_data)
    print(f"Target (news) data: {news_data.shape}")

    trajectory_dataset, summary_dataset = generate_trajectories_and_summaries(news_data, context_data, num_users)

    # Last path component of the checkpoint directory, used in file names.
    # NOTE(review): the "_4_" in the file names presumably refers to the
    # number of few-shot examples - confirm.
    model_tag = model_name.split('/')[-1]
    trajectory_dataset.to_csv(f"{output_dir}/{model_tag}_{num_users}_4_user_trajectories.csv", index=False)
    summary_dataset.to_csv(f"{output_dir}/{model_tag}_{num_users}_4_summary.csv", index=False)
